import os

from typing import Callable, Iterable, Optional, Tuple
from PIL import Image
from tqdm import tqdm

from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.sg_dataset import DirectoryDataSet, ListDataset
from super_gradients.training.samples import SegmentationSample


@register_dataset(Datasets.SEGMENTATION_DATASET)
class SegmentationDataSet(DirectoryDataSet, ListDataset, HasPreprocessingParams):
    """
    Generic segmentation dataset yielding (image, mask) pairs.

    Depending on whether ``list_file`` is passed, the instance is initialised through ``ListDataset.__init__``
    or through ``DirectoryDataSet.__init__``. Images and masks are opened with PIL, optionally cached in memory,
    and passed together through ``self.transforms`` so that augmentations stay aligned between image and mask.
    """

    @resolve_param("transforms", factory=TransformsFactory())
    def __init__(
        self,
        root: str,
        list_file: Optional[str] = None,
        samples_sub_directory: Optional[str] = None,
        targets_sub_directory: Optional[str] = None,
        cache_labels: bool = False,
        cache_images: bool = False,
        collate_fn: Optional[Callable] = None,
        target_extension: str = ".png",
        transforms: Optional[Iterable] = None,
    ):
        """
        SegmentationDataSet
            :param root:                        Root folder of the Data Set
            :param list_file:                   Path to the file with the samples list
            :param samples_sub_directory:       name of the samples sub-directory
            :param targets_sub_directory:       name of the targets sub-directory
            :param cache_labels:                "Caches" the labels -> Pre-Loads to memory as a list
            :param cache_images:                "Caches" the images -> Pre-Loads to memory as a list
            :param collate_fn:                  collate_fn func to process batches for the Data Loader
            :param target_extension:            file extension of the targets (default is .png for PASCAL VOC 2012)
            :param transforms:                  transforms to be applied on image and mask

        """
        # THESE ATTRIBUTES ARE SET BEFORE THE BASE-CLASS CONSTRUCTOR RUNS SINCE _generate_samples_and_targets READS THE
        # CACHE FLAGS (presumably it is invoked from within the base-class constructors - verify against sg_dataset)
        self.samples_sub_directory = samples_sub_directory
        self.targets_sub_directory = targets_sub_directory
        self.cache_labels = cache_labels
        self.cache_images = cache_images

        # CREATE A DIRECTORY DATASET OR A LIST DATASET BASED ON THE list_file INPUT VARIABLE
        if list_file is not None:
            ListDataset.__init__(
                self,
                root=root,
                file=list_file,
                target_extension=target_extension,
                sample_loader=self.sample_loader,
                target_loader=self.target_loader,
                collate_fn=collate_fn,
            )
        else:
            DirectoryDataSet.__init__(
                self,
                root=root,
                samples_sub_directory=samples_sub_directory,
                targets_sub_directory=targets_sub_directory,
                target_extension=target_extension,
                sample_loader=self.sample_loader,
                target_loader=self.target_loader,
                collate_fn=collate_fn,
            )

        # A FALSY VALUE (None / EMPTY) MEANS "NO TRANSFORMS"
        self.transforms = transforms if transforms else []

    def __getitem__(self, index):
        """
        Fetch a single (image, mask) pair, after applying the dataset transforms.
            :param index:   Index of the sample in self.samples_targets_tuples_list
            :return:        Tuple of (transformed image, transformed mask)
        """
        sample_path, target_path = self.samples_targets_tuples_list[index]

        # TRY TO LOAD THE CACHED IMAGE FIRST
        if self.cache_images:
            sample = self.imgs[index]
        else:
            sample = self.sample_loader(sample_path)

        # TRY TO LOAD THE CACHED LABEL FIRST
        if self.cache_labels:
            target = self.labels[index]
        else:
            target = self.target_loader(target_path)

        # MAKE SURE THE TRANSFORM WORKS ON BOTH IMAGE AND MASK TO ALIGN THE AUGMENTATIONS
        sample, target = self._transform_image_and_mask(sample, target)
        return sample, target

    @staticmethod
    def sample_loader(sample_path: str) -> Image.Image:
        """
        sample_loader - Loads a dataset image from path using PIL
            :param sample_path: The path to the sample image
            :return:            The loaded Image, converted to RGB
        """
        # convert() RETURNS A NEW, FULLY LOADED IMAGE, SO THE UNDERLYING FILE CAN BE CLOSED RIGHT AWAY
        with Image.open(sample_path) as opened_image:
            image = opened_image.convert("RGB")
        return image

    @staticmethod
    def target_loader(target_path: str) -> Image.Image:
        """
        target_loader - Loads a dataset mask from path using PIL (no mode conversion is applied)
            :param target_path: The path to the target (mask) image
            :return:            The loaded Image
        """
        target = Image.open(target_path)
        # Image.open IS LAZY AND KEEPS THE FILE OPEN UNTIL THE PIXEL DATA IS READ. LOADING EAGERLY RELEASES THE FILE
        # HANDLE, OTHERWISE CACHING MANY LABELS WOULD ACCUMULATE OPEN FILE DESCRIPTORS
        target.load()
        return target

    def _generate_samples_and_targets(self):
        """
        _generate_samples_and_targets - Makes sure self.samples_targets_tuples_list is populated and, when requested,
                                        pre-loads the images and/or labels into memory.

        Sets self.imgs and self.img_files when cache_images is True, and self.labels and self.label_files when
        cache_labels is True. Entries whose loader returned None are dropped consistently from the caches, the file
        lists and self.samples_targets_tuples_list, so that a given index refers to the same sample everywhere.

            :raises ValueError:     When no (sample, target) pairs are available
            :raises AssertionError: When cache_labels is True and no label could be loaded
        """
        # IF THE DERIVED CLASS DID NOT IMPLEMENT AN EXPLICIT _generate_samples_and_targets CHILD METHOD
        # NOTE(review): super() follows the MRO, in which DirectoryDataSet precedes ListDataset, also for instances
        # that were initialised through ListDataset.__init__ - confirm this is the intended implementation.
        if not self.samples_targets_tuples_list:
            super()._generate_samples_and_targets()

        # AN EMPTY LIST CANNOT BE UNZIPPED INTO TWO LISTS - FAIL WITH AN EXPLICIT MESSAGE
        if not self.samples_targets_tuples_list:
            raise ValueError("No samples were found for the dataset (samples_targets_tuples_list is empty).")

        # EXTRACT THE LABELS FROM THE TUPLES LIST
        image_files, label_files = map(list, zip(*self.samples_targets_tuples_list))
        num_samples = len(image_files)
        # A SET GIVES O(1) MEMBERSHIP TESTS WHEN FILTERING BELOW
        indices_to_remove = set()

        # CACHE IMAGES INTO MEMORY FOR FASTER TRAINING (WARNING: LARGE DATASETS MAY EXCEED SYSTEM RAM)
        if self.cache_images:
            # CREATE AN EMPTY LIST FOR THE IMAGES
            self.imgs = num_samples * [None]
            cached_images_mem_in_gb = 0.0
            with tqdm(image_files, desc="Caching images") as pbar:
                for i, img_path in enumerate(pbar):
                    img = self.sample_loader(img_path)
                    if img is None:
                        # NOTHING WAS CACHED FOR THIS SAMPLE - MARK IT FOR REMOVAL AND DO NOT COUNT ITS SIZE
                        indices_to_remove.add(i)
                        continue

                    # THE REPORTED SIZE IS THE ON-DISK FILE SIZE, NOT THE DECODED IN-MEMORY SIZE
                    cached_images_mem_in_gb += os.path.getsize(img_path) / 1024.0**3.0

                    self.imgs[i] = img
                    pbar.desc = "Caching images (%.1fGB)" % (cached_images_mem_in_gb)

        # CACHE LABELS INTO MEMORY FOR FASTER TRAINING - RELEVANT FOR EFFICIENT VALIDATION RUNS DURING TRAINING
        if self.cache_labels:
            # CREATE AN EMPTY LIST FOR THE LABELS
            self.labels = num_samples * [None]
            with tqdm(label_files, desc="Caching labels") as pbar:
                # duplicate_labels IS NEVER INCREMENTED - IT IS KEPT ONLY FOR THE PROGRESS-BAR MESSAGE
                missing_labels, found_labels, duplicate_labels = 0, 0, 0

                for i, file in enumerate(pbar):
                    labels = self.target_loader(file)

                    if labels is None:
                        missing_labels += 1
                        indices_to_remove.add(i)
                        continue

                    self.labels[i] = labels
                    found_labels += 1

                    pbar.desc = "Caching labels (%g found, %g missing, %g duplicate, for %g images)" % (
                        found_labels,
                        missing_labels,
                        duplicate_labels,
                        len(image_files),
                    )
            assert found_labels > 0, "No labels found."

        #  REMOVE THE IRRELEVANT ENTRIES FROM THE DATA - DONE ONCE, AFTER BOTH PASSES, SO THAT THE CACHES, THE FILE
        #  LISTS AND THE TUPLES LIST USED BY __getitem__ ALL STAY INDEX-ALIGNED
        if indices_to_remove:
            kept_indices = [i for i in range(num_samples) if i not in indices_to_remove]
            self.samples_targets_tuples_list = [self.samples_targets_tuples_list[i] for i in kept_indices]
            image_files = [image_files[i] for i in kept_indices]
            label_files = [label_files[i] for i in kept_indices]
            if self.cache_images:
                self.imgs = [self.imgs[i] for i in kept_indices]
            if self.cache_labels:
                self.labels = [self.labels[i] for i in kept_indices]

        if self.cache_images:
            self.img_files = image_files
        if self.cache_labels:
            self.label_files = label_files

    def _transform_image_and_mask(self, image, mask) -> tuple:
        """
        Applies every transform in self.transforms, in order, on the image and mask together.
        :param image:           The input image
        :param mask:            The input mask
        :return:                The transformed image, mask
        """
        sample = SegmentationSample(image=image, mask=mask)
        for t in self.transforms:
            sample = t.apply_to_sample(sample)
        return sample.image, sample.mask

    @property
    def _original_dataset_image_shape(self) -> Optional[Tuple[int, int]]:
        """
        Image default shape - (H,W)
        Default shape (model's input) should be defined for additional processing that might be needed
        when using "predict" any input-image/s can be used, the images should be rescaled to match the model's training-data shape

        This implementation returns None, in which case get_dataset_preprocessing_params adds no resize step.
        """
        return None

    def get_dataset_preprocessing_params(self):
        """
        Return any hardcoded preprocessing + adaptation for PIL.Image image reading (RGB).
         image_processor is returned as a list of dicts to be resolved by processing factory.
        :return: dict with the keys "class_names" (self.classes) and "image_processor" (a ComposeProcessing description
                 wrapping the resize step, if any, followed by the equivalent preprocessing of every transform)
        """
        pipeline = []

        if self._original_dataset_image_shape:
            # Resize image to same image-shape as model input. The default shape should be defined in the dataset class
            # by overriding the "_original_dataset_image_shape" property
            pipeline += [{Processings.SegmentationResizeWithPadding: {"output_shape": self._original_dataset_image_shape, "pad_value": 0}}]

        for t in self.transforms:
            pipeline += t.get_equivalent_preprocessing()
        params = dict(class_names=self.classes, image_processor={Processings.ComposeProcessing: {"processings": pipeline}})
        return params
